{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "145d7c1f-fc9d-4ae3-9e22-5cdb8a2d8dfa",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-06-10T02:17:41.830460Z",
     "iopub.status.busy": "2024-06-10T02:17:41.829883Z",
     "iopub.status.idle": "2024-06-10T02:17:41.839926Z",
     "shell.execute_reply": "2024-06-10T02:17:41.838229Z",
     "shell.execute_reply.started": "2024-06-10T02:17:41.830414Z"
    }
   },
   "outputs": [],
   "source": [
    "from IPython.display import Image"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7a8884bb-d4f8-4ec7-88a9-c74afd63d947",
   "metadata": {},
   "source": [
    "- Scaling and evaluating sparse autoencoders\n",
    "    - https://cdn.openai.com/papers/sparse-autoencoders.pdf\n",
    "    - https://github.com/openai/sparse_autoencoder\n",
    "    - https://openai.com/index/extracting-concepts-from-gpt-4/\n",
    "    - claude：\n",
    "        - https://transformer-circuits.pub/2023/monosemantic-features\n",
    "            - Using a sparse autoencoder, we extract a large number of interpretable features from a one-layer transformer.\n",
    "        - https://www.anthropic.com/news/decomposing-language-models-into-understandable-components\n",
    "        - For the first time, we feel that the next primary obstacle to interpreting large language models is engineering rather than science.\n",
    "- SAE（Sparse AutoEncoder）\n",
    "    - unsupervised approach for extracting interpretable features from a language model by **reconstructing activations** from a sparse bottleneck layer.\n",
    "    - Since language models learn **many concepts**, autoencoders need to be very large to recover all relevant features.\n",
    "        - Scaling SAEs\n",
    "    - **k-sparse autoencoders** [Makhzani and Frey, 2013] to directly control sparsity,\n",
    "        - https://arxiv.org/pdf/1312.5663\n",
    "    - we train a 16 million latent autoencoder on GPT-4 activations for 40 billion tokens.\n",
    "        - $16,000,000$\n",
    "        - the sparse autoencoder that supports GPT-4 was able to find 16 million features of GPT-4. \n",
    "- 神经网络的可解释性（to understand how neural networks work and think,）\n",
    "    - 'features' rather than neuron units.\n",
    "        - 'features that respond to legal texts'\n",
    "        - 'features that respond to DNA sequences,'\n",
    "    - When a large-scale language model generates each token in a sentence, only **a small part of the huge neural network fires** (sends a signal).\n",
    "- https://www.oxen.ai/blog/arxiv-dives"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "0e9bf025-ad87-43a2-aae3-af6d01ce4624",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-06-10T04:04:25.037574Z",
     "iopub.status.busy": "2024-06-10T04:04:25.036927Z",
     "iopub.status.idle": "2024-06-10T04:04:25.049256Z",
     "shell.execute_reply": "2024-06-10T04:04:25.047120Z",
     "shell.execute_reply.started": "2024-06-10T04:04:25.037528Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<img src=\"../../../imgs/claude-sae.png\" width=\"500\"/>"
      ],
      "text/plain": [
       "<IPython.core.display.Image object>"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Image(url='../../../imgs/claude-sae.png', width=500)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6d1cace1-9e3b-4f32-bcbf-7782b28c9ec3",
   "metadata": {},
   "source": [
    "\n",
    "- The first layer (“encoder”) maps the activity to a higher-dimensional layer via a learned linear transformation followed by a ReLU nonlinearity.\n",
    "    - 512 -> 131072 (256x)\n",
    "    - They refer to the units of this high-dimensional layer as “features.”\n",
    "- The second layer (“decoder”) attempts to reconstruct the model activations via a linear transformation of the feature activations.\n",
    "    - 131072 -> 512 "
   ]
  },
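  {
   "cell_type": "markdown",
   "id": "sae-param-count-sketch",
   "metadata": {},
   "source": [
    "As a rough sanity check on the sizes above (a back-of-the-envelope count that assumes dense weight matrices and ignores bias terms), each of the two linear maps holds\n",
    "\n",
    "$$\n",
    "512 \\times 131072 = 2^{9} \\cdot 2^{17} = 2^{26} = 67{,}108{,}864\n",
    "$$\n",
    "\n",
    "weights, so the encoder and decoder together hold about $1.34 \\times 10^{8}$ parameters."
   ]
  },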
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "09acac0a-db48-4142-a9ef-8ad6068e5c02",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-06-10T02:27:50.149188Z",
     "iopub.status.busy": "2024-06-10T02:27:50.148564Z",
     "iopub.status.idle": "2024-06-10T02:27:50.161477Z",
     "shell.execute_reply": "2024-06-10T02:27:50.159322Z",
     "shell.execute_reply.started": "2024-06-10T02:27:50.149142Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<img src=\"https://images.ctfassets.net/kftzwdyauwt9/5i0GAmvivjtoiLsTnJW6HS/c874efa090da2a90280be59fc424f2b4/sparse-autoencoder_dark.gif?w=3840&q=80&fm=webp\" width=\"400\"/>"
      ],
      "text/plain": [
       "<IPython.core.display.Image object>"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# top4\n",
    "Image(url='https://images.ctfassets.net/kftzwdyauwt9/5i0GAmvivjtoiLsTnJW6HS/c874efa090da2a90280be59fc424f2b4/sparse-autoencoder_dark.gif?w=3840&q=80&fm=webp', \n",
    "      width=400)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a258278c-b5b1-4260-84b5-7b4ebef74787",
   "metadata": {},
   "source": [
    "## SAE"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "58bc9295-6540-4afb-a45b-20e4c896c808",
   "metadata": {},
   "source": [
    "\n",
    "- $d, n, k$:\n",
    "    - d: input space\n",
    "    - n: latent(feature) space\n",
    "    - k: sparsity\n",
    "\n",
    "$$\n",
    "\\begin{align*}\n",
    "z &= \\text{ReLU}(W_{\\text{enc}}(x - b_{\\text{pre}}) + b_{\\text{enc}}) \\\\\n",
    "\\hat{x} &= W_{\\text{dec}}z + b_{\\text{pre}}\n",
    "\\end{align*}\n",
    "$$\n",
    "- $b_\\text{pre}$ 是在输入向量 $x$ 进行编码之前，从 𝑥 中减去的一个常数偏置。这种操作的目的是将数据中心化，\n",
    "\n",
    "- topK\n",
    "\n",
    "    $$\n",
    "    z = \\text{TopK}(W_{\\text{enc}}(x - b_{\\text{pre}}))\n",
    "    $$\n",
    "  -  We use a k-sparse autoencoder [Makhzani and Frey, 2013], which directly controls the number of active latents by using an **activation function (TopK)** that only keeps the k largest latents, zeroing the rest.\n",
    "- training loss\n",
    "    \n",
    "    $$\n",
    "    \\mathcal L=\\|x-\\hat x\\|^2_2\n",
    "    $$ \n",
    "\n",
    "    - It removes the need for the L1 penalty\n",
    "    - Calude: L2 reconstruction + L1 on hidden layer activation\n",
    "- Jointly fitting sparsity ($L(N, K)$)\n",
    "    - the number of latents $n$ and the sparsity level $k$\n",
    "\n",
    "$$\n",
    "\\begin{equation}\n",
    "L(n, k) = \\exp(\\alpha + \\beta_k \\log(k) + \\beta_n \\log(n) + \\gamma \\log(k) \\log(n)) + \\exp(\\zeta + \\eta \\log(k))\n",
    "\\end{equation}\n",
    "$$"
   ]
  },
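  {
   "cell_type": "markdown",
   "id": "sae-scaling-law-rewrite",
   "metadata": {},
   "source": [
    "To read the joint fit at a fixed sparsity level, it may help to rewrite it. The step below is plain algebra on the formula above, using $\\exp(a \\log k) = k^{a}$, and introduces no new assumptions:\n",
    "\n",
    "$$\n",
    "L(n, k) = e^{\\alpha}\\, k^{\\beta_k}\\, n^{\\beta_n + \\gamma \\log k} + e^{\\zeta}\\, k^{\\eta}\n",
    "$$\n",
    "\n",
    "So for fixed $k$, the first term is a power law in $n$ with exponent $\\beta_n + \\gamma \\log k$, while the second term does not depend on $n$ and is therefore unaffected by adding latents under this fit."
   ]
  },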
  {
   "cell_type": "markdown",
   "id": "1bcedafa-50d6-4403-baed-df0292571ac3",
   "metadata": {},
   "source": [
    "## topK backward"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "id": "49c682b5-f26d-49c0-bd64-18152faba5a5",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-06-10T04:56:10.341092Z",
     "iopub.status.busy": "2024-06-10T04:56:10.340455Z",
     "iopub.status.idle": "2024-06-10T04:56:10.354015Z",
     "shell.execute_reply": "2024-06-10T04:56:10.351897Z",
     "shell.execute_reply.started": "2024-06-10T04:56:10.341045Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x7657b0751550>"
      ]
     },
     "execution_count": 49,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "torch.manual_seed(42)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a46f0e4f-2acd-4b3d-886f-5be284986624",
   "metadata": {},
   "source": [
    "$$\n",
    "\\begin{split}\n",
    "y_1=W_1\\cdot x_1\\\\\n",
    "y_k=\\text{TopK}(y_1)\\\\\n",
    "\\ell=\\sum y_k^2\n",
    "\\end{split}\n",
    "$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "id": "343d1e7e-af5c-4ded-ac8d-bf8c55e44471",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-06-10T04:56:12.033206Z",
     "iopub.status.busy": "2024-06-10T04:56:12.032626Z",
     "iopub.status.idle": "2024-06-10T04:56:12.049087Z",
     "shell.execute_reply": "2024-06-10T04:56:12.047091Z",
     "shell.execute_reply.started": "2024-06-10T04:56:12.033163Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([2.1621, 2.6728, 1.9516, 0.8855, 2.4250, 2.3444], grad_fn=<MvBackward0>)\n",
      "tensor([2.6728, 2.4250, 2.3444], grad_fn=<TopkBackward0>)\n",
      "tensor([1, 4, 5])\n"
     ]
    }
   ],
   "source": [
    "# 6\n",
    "x1 = torch.rand(6, requires_grad = True)\n",
    "\n",
    "# 6*6\n",
    "W1 = torch.rand(6, 6, requires_grad = True)\n",
    "\n",
    "# 6\n",
    "y1 = W1 @ x1\n",
    "\n",
    "# 3\n",
    "yk, indices = torch.topk(y1, 3)\n",
    "\n",
    "print(y1,)\n",
    "print(yk,)\n",
    "print(indices)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "id": "5ab23965-8eac-458d-8426-ebfe913ffa8c",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-06-10T04:56:28.747598Z",
     "iopub.status.busy": "2024-06-10T04:56:28.746961Z",
     "iopub.status.idle": "2024-06-10T04:56:28.763899Z",
     "shell.execute_reply": "2024-06-10T04:56:28.761790Z",
     "shell.execute_reply.started": "2024-06-10T04:56:28.747553Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
       "        [4.7162, 4.8912, 2.0466, 5.1280, 2.0871, 3.2121],\n",
       "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
       "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
       "        [4.2791, 4.4378, 1.8569, 4.6527, 1.8937, 2.9144],\n",
       "        [4.1368, 4.2903, 1.7952, 4.4980, 1.8308, 2.8175]])"
      ]
     },
     "execution_count": 51,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# to scalar\n",
    "loss1 = (yk ** 2).sum()\n",
    "\n",
    "# topk operation is differential\n",
    "# grad_fn=<TopkBackward0>\n",
    "loss1.backward()\n",
    "W1.grad"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8d9cf024-a907-476a-9779-1426b4abec31",
   "metadata": {},
   "source": [
    "$$\n",
    "\\begin{split}\n",
    "&y_1=W_1 x_1\\\\\n",
    "&y_k=\\text{TopK}(y_1)=W_ky_1\\\\\n",
    "&\\ell=\\sum y_k^2\n",
    "\\end{split}\n",
    "$$\n",
    "\n",
    "$$\n",
    "\\frac{\\partial \\ell}{\\partial W_1} = \\frac{\\partial \\ell}{\\partial y_k} \\cdot \\frac{\\partial y_k}{\\partial y_1} \\cdot \\frac{\\partial y_1}{\\partial W_1} = 2 y_k \\cdot W_k \\cdot x_1\n",
    "$$"
   ]
  },
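  {
   "cell_type": "markdown",
   "id": "topk-grad-index-form",
   "metadata": {},
   "source": [
    "Written entry by entry, the same chain rule can be sketched as follows, where $S$ (a symbol introduced here only for illustration) denotes the set of indices kept by TopK:\n",
    "\n",
    "$$\n",
    "\\frac{\\partial \\ell}{\\partial (W_1)_{ij}} =\n",
    "\\begin{cases}\n",
    "2\\,(y_1)_i\\,(x_1)_j, & i \\in S \\\\\n",
    "0, & i \\notin S\n",
    "\\end{cases}\n",
    "$$\n",
    "\n",
    "With the values printed above, $S = \\{1, 4, 5\\}$ (0-indexed), which matches the nonzero rows of `W1.grad`."
   ]
  },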
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "772290e7-f12f-436c-bc84-835c80a4dc17",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-06-10T04:31:57.092999Z",
     "iopub.status.busy": "2024-06-10T04:31:57.092377Z",
     "iopub.status.idle": "2024-06-10T04:31:57.115180Z",
     "shell.execute_reply": "2024-06-10T04:31:57.112819Z",
     "shell.execute_reply.started": "2024-06-10T04:31:57.092953Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0., 1., 0., 0., 0., 0.],\n",
       "        [0., 0., 0., 0., 1., 0.],\n",
       "        [0., 0., 0., 0., 0., 1.]])"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "k = 3\n",
    "Wk = torch.zeros(k, y1.shape[0])\n",
    "Wk[torch.arange(k), indices] = 1\n",
    "Wk"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "id": "104aa19d-c3b4-4aaf-b06a-c3b04d9ddcd5",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-06-10T04:35:04.160663Z",
     "iopub.status.busy": "2024-06-10T04:35:04.160049Z",
     "iopub.status.idle": "2024-06-10T04:35:04.185262Z",
     "shell.execute_reply": "2024-06-10T04:35:04.182992Z",
     "shell.execute_reply.started": "2024-06-10T04:35:04.160619Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
       "        [4.7162, 4.8912, 2.0466, 5.1280, 2.0871, 3.2121],\n",
       "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
       "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
       "        [4.2791, 4.4378, 1.8569, 4.6527, 1.8937, 2.9144],\n",
       "        [4.1368, 4.2903, 1.7952, 4.4980, 1.8308, 2.8175]],\n",
       "       grad_fn=<MmBackward0>)"
      ]
     },
     "execution_count": 45,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(2*yk @ Wk).reshape(-1, 1) @ x1.reshape(1, 6)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
